/* Include files */
#include "math.h"
#include "mex.h"

/* Include to sort : http://stackoverflow.com/questions/1787996/c-library-function-to-do-sort */
/*
#include <stdio.h>
#include <stdlib.h>
*/

/*************************************************/
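/* Disabled helper (kept for reference): copies the positive entries of rowIn into rowOut and returns how many were copied */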
/*
static int AssignNonZeroVec(int *rowIn, int *rowOut, int A)
{
	int a = 0; int b = 0;
	for (a = 0; a < A; a++)
	{
		if (rowIn[a] > 0) 
		{
			rowOut[b] = rowIn[a]; 
			b++;
			printf("%d b \n",b);
		}
		printf("%d A \n",a);
	}
	return b;
}
*/

/*
 * Input-validation predicates for mxArray pointers.
 *
 * IS_REAL_2D_FULL_DOUBLE(P): true when P is a double array that is neither
 *                            sparse nor complex and has exactly two dimensions.
 * IS_REAL_SCALAR(P):         as above, and P holds exactly one element.
 *
 * Both macros evaluate P several times, so pass a plain pointer expression.
 */
#define IS_REAL_2D_FULL_DOUBLE(P) \
	(mxIsDouble(P) && !mxIsSparse(P) && !mxIsComplex(P) && \
	 mxGetNumberOfDimensions(P) == 2)
#define IS_REAL_SCALAR(P) \
	(mxGetNumberOfElements(P) == 1 && IS_REAL_2D_FULL_DOUBLE(P))

/*************************************************/
/* MAIN                                          */
/*************************************************/

/*
 * MEX gateway: per-market moments of firm shares and firm-level attribute means.
 *
 * Inputs (all must be real, full, 2-D double arrays):
 *   prhs[0] Amat        attribute matrix, read as A_pt[obs + n*a], a < columns(Amat)
 *   prhs[1] Bmat        attribute matrix, read as B_pt[obs + n*b], b < columns(Bmat)
 *   prhs[2] FirmMatch   1-based firm ID per observation; n = columns(FirmMatch).
 *                       IDs that are zero or negative mark an observation with no firm.
 *   prhs[3] Jm          Jm[0] is J, the number of firms
 *   prhs[4] GroupIndex  Maxlen-by-M; column m lists the 1-based observation indices
 *                       of market m, zero entries are padding and are ignored
 *   prhs[5] GroupCount  GroupCount[m] must exceed 1 for market m to be processed
 *   prhs[6] Share       read as S_pt[j + J*m] (firm j, market m)
 *
 * Outputs (at most nine; only the requested ones are returned):
 *   plhs[0] GpAShare  M-by-A   sum_j S(j,m) * GpAMean(j,m,a) / totFirm
 *   plhs[1] GpBShare  M-by-B   same, with the B attributes
 *   plhs[2] GpMarket  M-by-1   sum_j S(j,m) * marLen / totFirm
 *   plhs[3] GpFirm    M-by-1   sum_j S(j,m) / totFirm
 *   plhs[4] GpFMean   J-by-1   see NOTE(review) in the body
 *   plhs[5] GpACov    M-by-A   sum_j (S(j,m) - mean share of m) * GpAMean(j,m,a) / totFirm
 *   plhs[6] GpBCov    M-by-B   same, with the B attributes
 *   plhs[7] GpAMean   JM-by-A  mean of the A attributes over the observations of
 *                              firm j in market m, stored at row J*m + j
 *   plhs[8] GpBMean   JM-by-B  same, with the B attributes
 * where marLen is the number of non-zero entries in column m of GroupIndex and
 * totFirm is the number of firms with at least one observation in market m.
 * Markets with GroupCount[m] <= 1 keep all-zero rows.
 *
 * Errors are raised through mexErrMsgTxt for a wrong argument count, a wrongly
 * typed input, or an index that would address memory outside the inputs.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
	/* Output slots, in the order they are returned to MATLAB */
	enum {
		OUT_A_SHARE = 0,
		OUT_B_SHARE,
		OUT_MARKET,
		OUT_FIRM,
		OUT_F_MEAN,
		OUT_A_COV,
		OUT_B_COV,
		OUT_A_MEAN,
		OUT_B_MEAN,
		OUT_COUNT
	};

	const mxArray *Amat, *Bmat, *FirmMatch, *Jm, *GroupIndex, *GroupCount, *Share;
	mxArray *out[OUT_COUNT];

	double *A_pt, *B_pt, *S_pt, *Firm_pt, *Jm_pt, *GpIn_pt, *GpCt_pt;
	double *GpAShare_pt, *GpBShare_pt, *GpMarket_pt, *GpFirm_pt, *GpFMean_pt;
	double *GpACov_pt, *GpBCov_pt, *GpAMean_pt, *GpBMean_pt;

	int *rowVec, *obsFirm, *firmCt;
	int n, M, Maxlen, A, B, J, JM;
	int a, b, m, j, k, len, marLen, totFirm, nOut;
	double obsLimit, lim;

	/* ---------------------------------------------------------------------------- */
	/* Argument checks                                                              */
	/* ---------------------------------------------------------------------------- */

	/* Exactly seven inputs: Amat, Bmat, FirmMatch, Jm, GroupIndex, GroupCount, Share */
	if (nrhs < 7)
		mexErrMsgTxt("Too few input arguments.");
	else if (nrhs > 7)
		mexErrMsgTxt("Too many input arguments.");
	if (nlhs > OUT_COUNT)
		mexErrMsgTxt("Too many output arguments.");

	Amat       = prhs[0];
	Bmat       = prhs[1];
	FirmMatch  = prhs[2];
	Jm         = prhs[3];
	GroupIndex = prhs[4];
	GroupCount = prhs[5];
	Share      = prhs[6];

	if (!IS_REAL_2D_FULL_DOUBLE(Amat))
		mexErrMsgTxt("Amat must be real 2D full double array.");
	if (!IS_REAL_2D_FULL_DOUBLE(Bmat))
		mexErrMsgTxt("Bmat must be real 2D full double array.");
	if (!IS_REAL_2D_FULL_DOUBLE(FirmMatch))
		mexErrMsgTxt("FirmMatch must be real 2D full double array.");
	if (!IS_REAL_2D_FULL_DOUBLE(Jm))
		mexErrMsgTxt("Jm must be real 2D full double array.");
	if (!IS_REAL_2D_FULL_DOUBLE(GroupIndex))
		mexErrMsgTxt("GroupIndex must be real 2D full double array.");
	if (!IS_REAL_2D_FULL_DOUBLE(GroupCount))
		mexErrMsgTxt("GroupCount must be real 2D full double array.");
	if (!IS_REAL_2D_FULL_DOUBLE(Share))
		mexErrMsgTxt("Share must be real 2D full double array.");

	/* ---------------------------------------------------------------------------- */
	/* Dimensions and input pointers                                                */
	/* ---------------------------------------------------------------------------- */

	n      = (int)mxGetN(FirmMatch);
	M      = (int)mxGetN(GroupIndex);
	Maxlen = (int)mxGetM(GroupIndex);
	A      = (int)mxGetN(Amat);
	B      = (int)mxGetN(Bmat);

	A_pt    = mxGetPr(Amat);
	B_pt    = mxGetPr(Bmat);
	S_pt    = mxGetPr(Share);
	Firm_pt = mxGetPr(FirmMatch);
	Jm_pt   = mxGetPr(Jm);
	GpIn_pt = mxGetPr(GroupIndex);
	GpCt_pt = mxGetPr(GroupCount);

	if (mxGetNumberOfElements(Jm) < 1)
		mexErrMsgTxt("Jm must not be empty.");
	if (!(Jm_pt[0] >= 0))							/* also rejects NaN */
		mexErrMsgTxt("Jm must be a non-negative number.");

	J  = (int)Jm_pt[0];
	JM = (int)(Jm_pt[0]*M);

	/* GroupCount[m] is read for every market, Share[j + J*m] for every processed one */
	if ((double)mxGetNumberOfElements(GroupCount) < (double)M)
		mexErrMsgTxt("GroupCount must have one entry per column of GroupIndex.");
	if ((double)mxGetNumberOfElements(Share) < (double)J*(double)M)
		mexErrMsgTxt("Share must have at least Jm entries per column of GroupIndex.");

	/* Largest 1-based observation index for which Firm_pt[obs], A_pt[obs + n*a]
	   and B_pt[obs + n*b] all stay inside their arrays */
	obsLimit = (double)mxGetNumberOfElements(FirmMatch);
	if (A > 0)
	{
		lim = (double)mxGetNumberOfElements(Amat) - (double)n*(double)(A - 1);
		if (lim < obsLimit)
			obsLimit = lim;
	}
	if (B > 0)
	{
		lim = (double)mxGetNumberOfElements(Bmat) - (double)n*(double)(B - 1);
		if (lim < obsLimit)
			obsLimit = lim;
	}

	/* ---------------------------------------------------------------------------- */
	/* Outputs (zero-initialised by mxCreateDoubleMatrix)                           */
	/* ---------------------------------------------------------------------------- */

	out[OUT_A_SHARE] = mxCreateDoubleMatrix(M, A, mxREAL);
	out[OUT_B_SHARE] = mxCreateDoubleMatrix(M, B, mxREAL);
	out[OUT_MARKET]  = mxCreateDoubleMatrix(M, 1, mxREAL);
	out[OUT_FIRM]    = mxCreateDoubleMatrix(M, 1, mxREAL);
	out[OUT_F_MEAN]  = mxCreateDoubleMatrix(J, 1, mxREAL);
	out[OUT_A_COV]   = mxCreateDoubleMatrix(M, A, mxREAL);
	out[OUT_B_COV]   = mxCreateDoubleMatrix(M, B, mxREAL);
	out[OUT_A_MEAN]  = mxCreateDoubleMatrix(JM, A, mxREAL);
	out[OUT_B_MEAN]  = mxCreateDoubleMatrix(JM, B, mxREAL);

	GpAShare_pt = mxGetPr(out[OUT_A_SHARE]);
	GpBShare_pt = mxGetPr(out[OUT_B_SHARE]);
	GpMarket_pt = mxGetPr(out[OUT_MARKET]);
	GpFirm_pt   = mxGetPr(out[OUT_FIRM]);
	GpFMean_pt  = mxGetPr(out[OUT_F_MEAN]);
	GpACov_pt   = mxGetPr(out[OUT_A_COV]);
	GpBCov_pt   = mxGetPr(out[OUT_B_COV]);
	GpAMean_pt  = mxGetPr(out[OUT_A_MEAN]);
	GpBMean_pt  = mxGetPr(out[OUT_B_MEAN]);

	/* Scratch buffers. mxCalloc memory is released by MATLAB if mexErrMsgTxt
	   aborts the call, so the error paths below cannot leak. */
	rowVec  = mxCalloc(Maxlen, sizeof *rowVec);		/* 0-based observation indices of the market */
	obsFirm = mxCalloc(Maxlen, sizeof *obsFirm);	/* 0-based firm of each observation, -1 if none */
	firmCt  = mxCalloc(J, sizeof *firmCt);			/* observations per firm in the market */

	/* ---------------------------------------------------------------------------- */
	/* Share and attributes joint distribution                                      */
	/* ---------------------------------------------------------------------------- */

	for (m = 0; m < M; m++)
	{
		double nFirms, meanShare;

		/* Markets with a group count of one or less are left at zero */
		if (!(GpCt_pt[m] > 1))
			continue;

		for (j = 0; j < J; j++)
			firmCt[j] = 0;

		/* Compact the non-zero entries of column m, look up the firm of each
		   observation and count the observations per firm */
		marLen = 0;
		for (len = 0; len < Maxlen; len++)
		{
			double entry = GpIn_pt[len + Maxlen*m];
			double firm;

			if (entry == 0)
				continue;
			if (!(entry >= 1 && entry < obsLimit + 1))
				mexErrMsgTxt("GroupIndex contains an observation index outside the range of FirmMatch, Amat or Bmat.");

			rowVec[marLen] = (int)entry - 1;
			firm = Firm_pt[rowVec[marLen]] - 1;

			if (!(firm > -1))
			{
				obsFirm[marLen] = -1;				/* no firm attached to this observation */
			}
			else if (firm >= J)
			{
				mexErrMsgTxt("FirmMatch contains a firm ID larger than Jm.");
			}
			else
			{
				obsFirm[marLen] = (int)firm;
				firmCt[(int)firm]++;
			}
			marLen++;
		}

		/* Number of firms with at least one observation in market m */
		totFirm = 0;
		for (j = 0; j < J; j++)
		{
			if (firmCt[j] > 0)
				totFirm++;
		}
		/* NOTE(review): totFirm can be zero here (no observation carries a firm);
		   the divisions below then yield Inf/NaN, exactly as before this rewrite. */
		nFirms = (double)totFirm;

		/* Firm-level attribute means within market m */
		for (k = 0; k < marLen; k++)
		{
			int loc = obsFirm[k];
			int obs = rowVec[k];
			double cnt;

			if (loc < 0)
				continue;

			cnt = (double)firmCt[loc];				/* >= 1: this observation was counted above */
			for (a = 0; a < A; a++)
				GpAMean_pt[J*m + loc + JM*a] += A_pt[obs + n*a]/cnt;
			for (b = 0; b < B; b++)
				GpBMean_pt[J*m + loc + JM*b] += B_pt[obs + n*b]/cnt;
		}

		/* Market-level share statistics */
		meanShare = 0.0;
		for (j = 0; j < J; j++)
		{
			double s = S_pt[m*J + j];

			meanShare      += s/nFirms;
			GpMarket_pt[m] += s*((double)marLen/nFirms);
			GpFirm_pt[m]   += s/nFirms;

			/* NOTE(review): this reads Share as M-by-J (m + M*j) while every other
			   access reads it as J-by-M (j + J*m). Kept unchanged; confirm which
			   layout is intended. */
			GpFMean_pt[j]  += S_pt[M*j + m]/(double)M;
		}

		/* Share-weighted attribute means and share/attribute covariances */
		for (j = 0; j < J; j++)
		{
			double s   = S_pt[J*m + j];
			double dev = s - meanShare;				/* deviation from the mean share of market m */

			for (a = 0; a < A; a++)
			{
				double xbar = GpAMean_pt[J*m + j + JM*a];

				GpAShare_pt[m + M*a] += s*xbar/nFirms;
				GpACov_pt[m + M*a]   += dev*xbar/nFirms;
			}
			for (b = 0; b < B; b++)
			{
				double xbar = GpBMean_pt[J*m + j + JM*b];

				GpBShare_pt[m + M*b] += s*xbar/nFirms;
				GpBCov_pt[m + M*b]   += dev*xbar/nFirms;
			}
		}
	}

	/* ---------------------------------------------------------------------------- */
	/* Hand over the requested outputs                                              */
	/* ---------------------------------------------------------------------------- */

	/* plhs always has room for one output, even when nlhs is zero */
	nOut = (nlhs > 1) ? nlhs : 1;
	for (k = 0; k < OUT_COUNT; k++)
	{
		if (k < nOut)
			plhs[k] = out[k];
		else
			mxDestroyArray(out[k]);
	}

	mxFree(rowVec);
	mxFree(obsFirm);
	mxFree(firmCt);
}
